from plat.plat_train import get_train_info, plat_train_val, plat_trains

"""
比较隐藏层大小对最终结果的影响
隐藏层分别取16, 32, 64, 128
"""
if __name__ == '__main__':
    # Driver script: load the logs of four runs that differ in hidden-layer
    # size (16 / 32 / 64 / 128) and hand the parsed results to the plat_*
    # helpers imported from plat.plat_train.
    log_path = "../log/GatedGCN/hidden/"
    run_names = ['hidden_16', 'hidden_32', 'hidden_64', 'hidden_128']

    # Step 1: parse every log file first. get_train_info returns a pair that
    # is stored as (train, val) under the run's name; each log is named
    # "<run name>.log" inside log_path.
    train_dict = {}
    val_dict = {}
    for run in run_names:
        train_dict[run], val_dict[run] = get_train_info(log_path + run + ".log")

    # Step 2: per-run call with that run's train and val data
    # (presumably one train-vs-val figure per run -- confirm in plat_train).
    for run in run_names:
        plat_train_val(run, train_dict[run], val_dict[run])

    # Step 3: cross-run calls, first with all train data, then all val data.
    plat_trains(train_dict)
    plat_trains(val_dict)
